{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lRuXm7d4UpIT"
      },
      "outputs": [],
      "source": [
        "# have you installed snn torch?\n",
        "# %pip install snntorch"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qtbidebHU1PP"
      },
      "outputs": [],
      "source": [
        "import snntorch as snn\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "DATADIR = \"/tmp/data\""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Yr8wE2ybU4RQ"
      },
      "source": [
        "## Download Dataset using `spikedata` (deprecated)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rhskczZfU3Wo"
      },
      "outputs": [],
      "source": [
        "from snntorch.spikevision import spikedata \n",
        "# note that a default transform is already applied\n",
        "train_ds = spikedata.NMNIST(f\"{DATADIR}/nmnist\", train=True, num_steps=300, dt=1000) # dt is the # of microseconds integrated\n",
        "test_ds = spikedata.NMNIST(f\"{DATADIR}/nmnist\", train=False, num_steps=300, dt=1000)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IUoggS_uVAB3"
      },
      "outputs": [],
      "source": [
        "train_ds "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Download Dataset using `tonic`"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import tonic\n",
        "train_ds = tonic.datasets.NMNIST(save_to=DATADIR, train=True)\n",
        "test_ds = tonic.datasets.NMNIST(save_to=DATADIR, train=False)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import tonic\n",
        "import tonic.transforms as transforms\n",
        "\n",
        "sensor_size = tonic.datasets.NMNIST.sensor_size\n",
        "\n",
        "# Denoise removes isolated, one-off events\n",
        "# time_window\n",
        "frame_transform = transforms.Compose([transforms.Denoise(filter_time=10000),\n",
        "                                      transforms.ToFrame(sensor_size=sensor_size,\n",
        "                                                         time_window=1000),\n",
        "                                     ])\n",
        "\n",
        "train_ds = tonic.datasets.NMNIST(save_to=DATADIR, transform=frame_transform, train=True)\n",
        "test_ds = tonic.datasets.NMNIST(save_to=DATADIR, transform=frame_transform, train=False)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "train_ds"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KKPN98YwVD4V"
      },
      "source": [
        "## Create DataLoader"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "aC1cno26VEu8"
      },
      "outputs": [],
      "source": [
        "from torch.utils.data import DataLoader\n",
        "\n",
        "train_dl = DataLoader(train_ds, shuffle=True, batch_size=64)\n",
        "test_dl = DataLoader(test_ds, shuffle=False, batch_size=64)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "print('the number of items in the dataset is', len(train_dl.dataset))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Play with Data"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# get a feel for the data\n",
        "i_item = 20000 # random index into a sample\n",
        "data, label = train_dl.dataset[i_item]\n",
        "import torch\n",
        "data = torch.Tensor(data)\n",
        "\n",
        "print('The data sample has size', data.shape)\n",
        "print(f\"in case you're blind AF, the target is: {label}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-n3lDA8PVKMv"
      },
      "source": [
        "## Visualize"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-zNTZOC3VIKg"
      },
      "outputs": [],
      "source": [
        "import matplotlib.pyplot as plt\n",
        "import snntorch.spikeplot as splt\n",
        "from IPython.display import HTML, display\n",
        "import numpy as np\n",
        "\n",
        "# flatten on-spikes and off-spikes into one channel\n",
        "# a = (train_dl.dataset[n][0][:, 0] + train_dl.dataset[n][0][:, 1])\n",
        "data = (data>=1).float() # some spikes are equal to 2...\n",
        "a = (data[:, 0, :, :] - data[:, 1, :, :])\n",
        "# a = np.swapaxes(a, 0, -1)\n",
        "#  Plot\n",
        "fig, ax = plt.subplots()\n",
        "anim = splt.animator(a, fig, ax, interval=30, cmap='seismic')\n",
        "HTML(anim.to_html5_video())\n",
        "# anim.save('nmnist_animation.mp4', writer = 'ffmpeg', fps=50)  "
      ]
    }
  ],
  "metadata": {
    "colab": {
      "name": "NMNIST.ipynb",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.6"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
